import asyncio
import socket
import os

from datetime import datetime

from py_pli.pylib import VUnits
from py_pli.pylib import Measurements
from py_pli.pylib import GlobalVar

import config_enum.excitationlight_selector_enum as els_config
import config_enum.detector_aperture_slider as as_config
import config_enum.scan_table_enum as st_config

from meas_services.instrument import InstrumentService

from virtualunits.HAL import HAL
from virtualunits.meas_seq_generator import meas_seq_generator
from virtualunits.meas_seq_generator import OutputSignal
from virtualunits.meas_seq_generator import MeasurementChannel

from fleming.common.firmware_util import send_gc_msg

# Module-level handles, resolved once at import time (the singletons behind
# Measurements.instance / VUnits.instance presumably must exist before import).

# Measurement service; used by alpha_scan_init to initialize the instrument.
instrument: InstrumentService = Measurements.instance.instrument

# Hardware abstraction layer and the sub-units driven by this module.
hal_unit: HAL = VUnits.instance.hal
meas_unit = hal_unit.measurementUnit            # loads/executes trigger sequences, reads results
els_unit = hal_unit.excitationLightSelector      # moved to the Alpha position in alpha_scan_init
fms_unit = hal_unit.filterModuleSlider           # selects the filter module by ID
fm_unit = hal_unit.focusMover                    # moved to plate height
st_unit = hal_unit.scan_table                    # plate type, measurement position, well moves
as1_unit = hal_unit.detectorApertureSlider1      # aperture chosen from the plate's well count
as2_unit = hal_unit.detectorApertureSlider2      # NOTE(review): not referenced in the visible code


async def alpha_scan_init(filter_id=5001, plate="'96 OptiPlate (Black)'"):
    """Prepare the instrument for alpha scans.

    Initializes the instrument, sets the plate type and the top-left
    measurement position, moves the excitation light selector to its Alpha
    position, selects the filter module, moves the focus to plate height,
    chooses a detector aperture from the plate's well count, runs a short
    sequence that switches the alpha excitation signal off, and finally
    enables alpha laser power.

    Args:
        filter_id: ID of the module selected on the filter module slider.
        plate: Plate type name passed to ``st_unit.SetPlateType``. The default
            contains literal single quotes — presumably the format the scan
            table expects; TODO confirm.

    Returns:
        The status string "alpha_scan_init() done".
    """
    await send_gc_msg("Initializing Alpha Scan")
    await instrument.InitializeInstrument()

    st_unit.SetPlateType(plate)
    st_unit.SetCurrentMeasPosition(st_config.GC_Params.FBDTop_TopLeftCorner)

    await els_unit.GotoPosition(els_config.Positions.Alpha)
    await fms_unit.SelectModuleWithId(filter_id)
    await fm_unit.GotoPlateHeight()

    # Query the plate geometry once instead of once per dimension.
    plate_dims = st_unit.get_current_plate_dimensions()
    no_of_wells = plate_dims.WellsX * plate_dims.WellsY
    # Aperture Ap_3_0 for plates with up to 96 wells, Ap_1_6 for denser plates.
    as_pos = as_config.Positions.Ap_3_0 if no_of_wells <= 96 else as_config.Positions.Ap_1_6
    await as1_unit.GotoPosition(as_pos)

    # Minimal trigger sequence that only sets the Alpha output signal, i.e.
    # (per the op id) switches the alpha excitation off before powering the laser.
    op_id = 'alpha_excitation_off'
    seq_gen = meas_seq_generator()
    seq_gen.SetSignals(OutputSignal.Alpha)
    seq_gen.Stop()
    meas_unit.ClearOperations()
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)
    await meas_unit.ExecuteMeasurement(op_id)

    await meas_unit.EnableAlphaLaserPower(True)

    return "alpha_scan_init() done"
    

async def alpha_scan(col, row, iterations=10, exc_time_ms=100.0, det_time_ms=100.0, window_ms=1.0, include_ext=0, hv_gating=1):
    """Repeatedly run the alpha excitation/detection sequence on one well and log counts to CSV.

    Moves the scan table to well (col, row), then runs ``alpha_scan_operation``
    ``iterations`` times, writing one row of per-window PMT1 counts per
    iteration to ``alpha_scan_results/alpha_scan__<hostname>_<timestamp>.csv``
    in this file's directory. Fields are separated by " ; ".

    Args:
        col, row: Well to measure (passed to ``st_unit.MoveToWell``).
        iterations: Number of scans; a falsy value falls back to 10.
        exc_time_ms: Excitation phase length in ms; falsy -> 100.0.
        det_time_ms: Detection phase length in ms; falsy -> 100.0.
        window_ms: Counting window length in ms; falsy -> 1.0. Rounded to
            whole sequencer ticks of 10 ns.
        include_ext: If truthy, excitation-phase windows are recorded too.
        hv_gating: If truthy, the PMT1 HV gate is held off during excitation.
            ``None``/``""`` fall back to 1; an explicit ``0`` disables gating.

    Returns:
        "alpha_scan() done", or "alpha_scan() stopped by user" if the GC stop
        flag is raised between iterations.

    Raises:
        ValueError: From ``alpha_scan_operation`` when the derived window
            counts are out of range (the CSV header is already written then).
    """
    # Falsy arguments (presumably None/empty fields from the caller) fall back
    # to the defaults.
    iterations = iterations if iterations else 10
    exc_time_ms = exc_time_ms if exc_time_ms else 100.0
    det_time_ms = det_time_ms if det_time_ms else 100.0
    window_ms = window_ms if window_ms else 1.0
    include_ext = include_ext if include_ext else 0
    # 0 is a meaningful value here (gating disabled); a plain truthiness test
    # would silently force it back to 1, so only a missing value is defaulted.
    hv_gating = 1 if hv_gating is None or hv_gating == "" else hv_gating

    delay = 0.1

    # Sequencer timer ticks are 10 ns, so 1 ms = 1e5 ticks.
    window = round(window_ms * 1e5)
    exc_window_count = round(exc_time_ms * 1e5 / window)
    det_window_count = round(det_time_ms * 1e5 / window)
    # Actual (tick-quantized) window length in ms, used for the CSV time axis.
    window_len_ms = window / 1e5

    await send_gc_msg(f"Starting alpha_scan(iterations={iterations}, exc_time_ms={exc_time_ms}, det_time_ms={det_time_ms}, window_ms={window_ms}, include_ext={include_ext}, hv_gating={hv_gating})")

    GlobalVar.set_stop_gc(False)

    await st_unit.MoveToWell(col, row)

    hostname = socket.gethostname()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # os.path.join avoids producing "/alpha_scan_results" (filesystem root)
    # when dirname(__file__) is empty.
    report_dir = os.path.join(os.path.dirname(__file__), "alpha_scan_results")
    os.makedirs(report_dir, exist_ok=True)
    report_path = os.path.join(report_dir, f"alpha_scan__{hostname}_{timestamp}.csv")

    # Window indices in result order: excitation windows (<= 0) first when
    # included, then detection windows 1..det_window_count.
    if include_ext:
        window_indices = range(1 - exc_window_count, det_window_count + 1)
    else:
        window_indices = range(1, det_window_count + 1)

    with open(report_path, 'w') as file:
        file.write(f"alpha_scan(iterations={iterations}, exc_time_ms={exc_time_ms}, det_time_ms={det_time_ms}, window_ms={window_ms}, include_ext={include_ext}, hv_gating={hv_gating}) on {hostname} started at {timestamp}\n")
        # Each column is labelled with the end time of its window in ms,
        # relative to the start of the detection phase.
        time_axis = "".join(f" ; {(k * window_len_ms):8.3f}" for k in window_indices)
        file.write(f"time [ms]{time_axis}\n")

        for i in range(iterations):
            if GlobalVar.get_stop_gc() is True:
                return "alpha_scan() stopped by user"

            await send_gc_msg(f"Scanning Iteration: {i + 1}")
            results = await alpha_scan_operation(exc_window_count, det_window_count, window, include_ext, hv_gating)
            counts = "".join(f" ; {value:8d}" for value in results)
            file.write(f"Scan #{(i + 1):02d} {counts}\n")
            await asyncio.sleep(delay)

    return "alpha_scan() done"


async def alpha_scan_operation(exc_window_count, det_window_count, window, include_ext, hv_gating):
    """Build, load and run one alpha excitation/detection trigger sequence on PMT1.

    The sequence turns the alpha excitation on for ``exc_window_count``
    windows, then off for ``det_window_count`` windows, storing the PMT1
    pulse count of each recorded window in consecutive result-buffer slots.
    (Per the 'alpha_excitation_off' op in alpha_scan_init, setting the Alpha
    output signal switches the excitation off; resetting it switches it on.)

    Args:
        exc_window_count: Number of excitation-phase windows, in [1, 65536].
        det_window_count: Number of detection-phase windows, in [1, 65536].
        window: Window length in sequencer timer ticks, in [1, 67108864].
            Ticks are 10 ns, judging by ``hv_gate_delay = 1000  # 10 us``.
        include_ext: If truthy, excitation-phase windows are stored too,
            ahead of the detection-phase windows.
        hv_gating: If truthy, the PMT1 HV gate is reset (off) until the
            detection phase; otherwise it is set (on) from the start.

    Returns:
        Whatever ``meas_unit.ReadMeasurementValues`` returns for result
        addresses 0..result_count-1 — presumably one pulse count per window.

    Raises:
        ValueError: If a count or the window is out of range, or more than
            4096 results would be produced.
    """
    if (exc_window_count < 1) or (exc_window_count > 65536):
        raise ValueError(f"exc_window_count must be in the range [1, 65536]")
    if (det_window_count < 1) or (det_window_count > 65536):
        raise ValueError(f"det_window_count must be in the range [1, 65536]")
    if (window < 1) or (window > 67108864):
        raise ValueError(f"window must be in the range [1, 67108864]")

    # One result slot per recorded window; excitation windows only count when included.
    result_count = det_window_count if not include_ext else (exc_window_count + det_window_count)
    if (result_count > 4096):
        raise ValueError(f"result_count must be less or equal to 4096")

    hv_gate_delay = 1000    # 10 us

    op_id = 'alpha_scan'
    seq_gen = meas_seq_generator()

    # High voltage gate on if no HV gating
    # (the hv_gate_delay timer started here is presumably what the
    # TimerWait() before the excitation phase waits for — TODO confirm)
    seq_gen.TimerWaitAndRestart(hv_gate_delay)
    if hv_gating:
        seq_gen.ResetSignals(OutputSignal.HVGatePMT1)
    else:
        seq_gen.SetSignals(OutputSignal.HVGatePMT1)

    # Clear the result buffer
    # (address register 0 starts at 0 and is incremented by 1 per cleared slot)
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    seq_gen.Loop(result_count)
    seq_gen.ClearResultBuffer(relative=True, dword=False, addrReg=0, addr=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    # Reset address register 0 for the result offset
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)

    # Excitation Phase
    # Reset the Alpha signal (excitation on), then restart the timer with the
    # window length and reset the PMT1 counter before the first window.
    seq_gen.TimerWait()
    seq_gen.ResetSignals(OutputSignal.Alpha)
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.Loop(exc_window_count)
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    if include_ext:
        # Store this window's count at the current offset and advance the offset.
        seq_gen.GetPulseCounterResult(MeasurementChannel.PMT1, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=0)
        seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    # Detection Phase
    # Set Alpha (excitation off) and HVGatePMT1 (gate on) together, then store
    # one count per window at consecutive offsets.
    seq_gen.SetSignals(OutputSignal.Alpha | OutputSignal.HVGatePMT1)
    seq_gen.Loop(det_window_count)
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT1, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    # Leave the HV gate off once the sequence is finished.
    seq_gen.ResetSignals(OutputSignal.HVGatePMT1)
    seq_gen.Stop(0)

    # Replace any previously loaded operations, register the result addresses
    # for this op, then run it and read the counts back.
    meas_unit.ClearOperations()
    meas_unit.resultAddresses[op_id] = range(0, result_count)
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)
    await meas_unit.ExecuteMeasurement(op_id)
    results = await meas_unit.ReadMeasurementValues(op_id)

    return results

